{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LLM Flow using RAG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts.chat import SystemMessagePromptTemplate\n",
    "from langchain.prompts.chat import HumanMessagePromptTemplate\n",
    "from langchain.prompts.chat import ChatPromptTemplate\n",
    "from langchain.embeddings import VertexAIEmbeddings\n",
    "from langchain.document_loaders import JSONLoader\n",
    "from langchain.embeddings.base import Embeddings\n",
    "from langchain.chat_models import ChatVertexAI\n",
    "from langchain.vectorstores import FAISS\n",
    "from google.cloud import bigquery\n",
    "from typing import List\n",
    "from tqdm import tqdm\n",
    "import logging\n",
    "import json\n",
    "import os "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Setup logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger = logging.getLogger('langchain')\n",
    "logger.setLevel(logging.INFO)\n",
    "logger.addHandler(logging.StreamHandler())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setup essentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "SERVICE_ACCOUNT_KEY_PATH = './../credentials/vai-key.json'\n",
    "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_KEY_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT = 'arun-genai-bb'\n",
    "LOCATION = 'us-central1'\n",
    "MODEL_NAME = 'codechat-bison@latest'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatVertexAI(project=PROJECT, \n",
    "                   location=LOCATION, \n",
    "                   model_name=MODEL_NAME,\n",
    "                   temperature=0.0, \n",
    "                   max_output_tokens=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyVertexAIEmbeddings(VertexAIEmbeddings, Embeddings):\n",
    "    model_name = 'textembedding-gecko'\n",
    "    max_batch_size = 5\n",
    "    \n",
    "    def embed_segments(self, segments: List) -> List:\n",
    "        embeddings = []\n",
    "        for i in tqdm(range(0, len(segments), self.max_batch_size)):\n",
    "            batch = segments[i: i+self.max_batch_size]\n",
    "            embeddings.extend(self.client.get_embeddings(batch))\n",
    "        return [embedding.values for embedding in embeddings]\n",
    "    \n",
    "    def embed_query(self, query: str) -> List:\n",
    "        embeddings = self.client.get_embeddings([query])\n",
    "        return embeddings[0].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding = MyVertexAIEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"Provide a list of all flight reservations from October 10th to October 15th, 2023\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 1: Embed and index tables info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "documents = JSONLoader(file_path='./../data/rag-schema/tables.jsonl', jq_schema='.', text_content=False, json_lines=True).load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "db = FAISS.from_documents(documents=documents, embedding=embedding)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 2: Match indexed tables embedding to incoming query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = db.as_retriever(search_type='mmr', search_kwargs={'k': 5, 'lambda_mult': 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "matched_documents = retriever.get_relevant_documents(query=query)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Matched tables = ['flight_reservations.reservations', 'flight_reservations.transactions', 'flight_reservations.flights', 'hotel_reservations.reservations', 'hotel_reservations.inventory']\n"
     ]
    }
   ],
   "source": [
    "matched_tables = []\n",
    "\n",
    "for document in matched_documents:\n",
    "    page_content = document.page_content\n",
    "    page_content = json.loads(page_content)\n",
    "    dataset_name = page_content['dataset_name']\n",
    "    table_name = page_content['table_name']\n",
    "    matched_tables.append(f'{dataset_name}.{table_name}')\n",
    "\n",
    "logger.info(f'Matched tables = {matched_tables}')\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 3: Embed and index columns info\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "documents = JSONLoader(file_path='./../data/rag-schema/columns.jsonl', jq_schema='.', text_content=False, json_lines=True).load()\n",
    "db = FAISS.from_documents(documents=documents, embedding=embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Document(page_content='{\"dataset_name\": \"hotel_reservations\", \"table_name\": \"hotels\", \"column_name\": \"hotel_id\", \"description\": \"A unique identifier assigned to each hotel.\", \"usage\": \"This ID helps in maintaining a distinct record for each hotel and acts as a primary key. It\\'s also used for referencing in other tables like Rooms.\", \"data_type\": \"INT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 1})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "documents[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 4: Match indexed columns embedding to incoming query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "search_kwargs = {\n",
    "    'k': 20\n",
    "}\n",
    "\n",
    "retriever = db.as_retriever(search_type='similarity', search_kwargs=search_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"reservations\", \"column_name\": \"reservation_datetime\", \"description\": \"Timestamp of when the reservation was made.\", \"usage\": \"Helps track reservation history and manage bookings.\", \"data_type\": \"DATETIME\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 62}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"flights\", \"column_name\": \"departure_datetime\", \"description\": \"The departure time of the flight.\", \"usage\": \"Informs users and helps them plan their travel.\", \"data_type\": \"DATETIME\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 55}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"flights\", \"column_name\": \"origin\", \"description\": \"The departure location of the flight.\", \"usage\": \"Helps users find flights based on their travel plans.\", \"data_type\": \"STRING\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 53}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"flights\", \"column_name\": \"destination\", \"description\": \"The arrival location of the flight.\", \"usage\": \"Used to find flights and plan journeys.\", \"data_type\": \"STRING\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 54}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"transactions\", \"column_name\": \"transaction_datetime\", \"description\": \"Timestamp of when the transaction occurred.\", \"usage\": \"Used for financial records, reporting, and auditing.\", \"data_type\": \"DATETIME\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 67}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"flights\", \"column_name\": \"arrival_datetime\", \"description\": \"The arrival time of the flight.\", \"usage\": \"Informs users and helps them plan their travel.\", \"data_type\": \"DATETIME\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 56}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"customers\", \"column_name\": \"date_of_birth\", \"description\": \"The birth date of the customer.\", \"usage\": \"May be used for age verification and personalized offers.\", \"data_type\": \"DATE\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 50}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"customers\", \"column_name\": \"created_at\", \"description\": \"Timestamp of when the customer data was added to the database.\", \"usage\": \"Helps track customer tenure and data age.\", \"data_type\": \"DATETIME\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 51}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"reservations\", \"column_name\": \"status\", \"description\": \"The status of the reservation (e.g., confirmed, cancelled).\", \"usage\": \"Informs users and staff of the current state of the reservation.\", \"data_type\": \"STRING\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 63}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"transactions\", \"column_name\": \"reservation_id\", \"description\": \"A unique identifier for each reservation.\", \"usage\": \"Used to uniquely identify and manage reservation records.\", \"data_type\": \"INT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 65}), Document(page_content='{\"dataset_name\": \"hotel_reservations\", \"table_name\": \"reservations\", \"column_name\": \"start_date\", \"description\": \"Indicates the beginning date of the reservation.\", \"usage\": \"Helps in determining room availability and the user\\'s stay period.\", \"data_type\": \"DATE\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 8}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"flights\", \"column_name\": \"carrier\", \"description\": \"The airline operating the flight.\", \"usage\": \"Provides users with the choice of airline and informs about the operator.\", \"data_type\": \"STRING\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 57}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"reservations\", \"column_name\": \"reservation_id\", \"description\": \"A unique identifier for each reservation.\", \"usage\": \"Used to uniquely identify and manage reservation records.\", \"data_type\": \"INT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 59}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"transactions\", \"column_name\": \"amount\", \"description\": \"The monetary value of the transaction.\", \"usage\": \"Used for accounting and financial tracking.\", \"data_type\": \"FLOAT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 66}), Document(page_content='{\"dataset_name\": \"hotel_reservations\", \"table_name\": \"reservations\", \"column_name\": \"end_date\", \"description\": \"Marks the termination date of the reservation.\", \"usage\": \"Assists in room inventory management and billing.\", \"data_type\": \"DATE\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 9}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"transactions\", \"column_name\": \"transaction_id\", \"description\": \"A unique identifier for each transaction.\", \"usage\": \"Ensures each transaction is distinct and can be tracked separately.\", \"data_type\": \"INT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 64}), Document(page_content='{\"dataset_name\": \"hotel_reservations\", \"table_name\": \"events\", \"column_name\": \"date\", \"description\": \"The date on which the event is scheduled.\", \"usage\": \"Used for event planning and scheduling.\", \"data_type\": \"STRING\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 45}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"flights\", \"column_name\": \"flight_id\", \"description\": \"A unique identifier for each flight.\", \"usage\": \"Used to uniquely identify and manage flight records.\", \"data_type\": \"INT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 52}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"reservations\", \"column_name\": \"flight_id\", \"description\": \"A unique identifier for each flight.\", \"usage\": \"Used to uniquely identify and manage flight records.\", \"data_type\": \"INT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 61}), Document(page_content='{\"dataset_name\": \"flight_reservations\", \"table_name\": \"customers\", \"column_name\": \"customer_id\", \"description\": \"A unique identifier assigned to each customer.\", \"usage\": \"Ensures each customer is distinct and can be referenced in reservations.\", \"data_type\": \"INT64\"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 46})]\n"
     ]
    }
   ],
   "source": [
    "matched_columns = retriever.get_relevant_documents(query=query)\n",
    "logger.info(matched_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[{'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'reservation_datetime', 'description': 'Timestamp of when the reservation was made.', 'usage': 'Helps track reservation history and manage bookings.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'departure_datetime', 'description': 'The departure time of the flight.', 'usage': 'Informs users and helps them plan their travel.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'origin', 'description': 'The departure location of the flight.', 'usage': 'Helps users find flights based on their travel plans.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'destination', 'description': 'The arrival location of the flight.', 'usage': 'Used to find flights and plan journeys.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'transactions', 'column_name': 'transaction_datetime', 'description': 'Timestamp of when the transaction occurred.', 'usage': 'Used for financial records, reporting, and auditing.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'arrival_datetime', 'description': 'The arrival time of the flight.', 'usage': 'Informs users and helps them plan their travel.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'customers', 'column_name': 'date_of_birth', 'description': 'The birth date of the customer.', 'usage': 'May be used for age verification and personalized offers.', 'data_type': 'DATE'}, {'dataset_name': 'flight_reservations', 'table_name': 'customers', 'column_name': 'created_at', 'description': 'Timestamp of when the customer data was added to the database.', 'usage': 'Helps track customer tenure and data age.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'status', 'description': 'The status of the reservation (e.g., confirmed, cancelled).', 'usage': 'Informs users and staff of the current state of the reservation.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'transactions', 'column_name': 'reservation_id', 'description': 'A unique identifier for each reservation.', 'usage': 'Used to uniquely identify and manage reservation records.', 'data_type': 'INT64'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'carrier', 'description': 'The airline operating the flight.', 'usage': 'Provides users with the choice of airline and informs about the operator.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'reservation_id', 'description': 'A unique identifier for each reservation.', 'usage': 'Used to uniquely identify and manage reservation records.', 'data_type': 'INT64'}, {'dataset_name': 'flight_reservations', 'table_name': 'transactions', 'column_name': 'amount', 'description': 'The monetary value of the transaction.', 'usage': 'Used for accounting and financial tracking.', 'data_type': 'FLOAT64'}, {'dataset_name': 'flight_reservations', 'table_name': 'transactions', 'column_name': 'transaction_id', 'description': 'A unique identifier for each transaction.', 'usage': 'Ensures each transaction is distinct and can be tracked separately.', 'data_type': 'INT64'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'flight_id', 'description': 'A unique identifier for each flight.', 'usage': 'Used to uniquely identify and manage flight records.', 'data_type': 'INT64'}, {'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'flight_id', 'description': 'A unique identifier for each flight.', 'usage': 'Used to uniquely identify and manage flight records.', 'data_type': 'INT64'}, {'dataset_name': 'flight_reservations', 'table_name': 'customers', 'column_name': 'customer_id', 'description': 'A unique identifier assigned to each customer.', 'usage': 'Ensures each customer is distinct and can be referenced in reservations.', 'data_type': 'INT64'}]\n"
     ]
    }
   ],
   "source": [
    "matched_columns_filtered = []\n",
    "\n",
    "# LangChain filters does not support multiple values at the moment\n",
    "for i, column in enumerate(matched_columns):\n",
    "    page_content = json.loads(column.page_content)\n",
    "    dataset_name = page_content['dataset_name']\n",
    "    if dataset_name == 'flight_reservations':\n",
    "        matched_columns_filtered.append(page_content)\n",
    "\n",
    "logger.info(matched_columns_filtered)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "dataset_name=flight_reservations|table_name=reservations|column_name=reservation_datetime|data_type=DATETIME\n",
      "dataset_name=flight_reservations|table_name=flights|column_name=departure_datetime|data_type=DATETIME\n",
      "dataset_name=flight_reservations|table_name=flights|column_name=origin|data_type=STRING\n",
      "dataset_name=flight_reservations|table_name=flights|column_name=destination|data_type=STRING\n",
      "dataset_name=flight_reservations|table_name=transactions|column_name=transaction_datetime|data_type=DATETIME\n",
      "dataset_name=flight_reservations|table_name=flights|column_name=arrival_datetime|data_type=DATETIME\n",
      "dataset_name=flight_reservations|table_name=customers|column_name=date_of_birth|data_type=DATE\n",
      "dataset_name=flight_reservations|table_name=customers|column_name=created_at|data_type=DATETIME\n",
      "dataset_name=flight_reservations|table_name=reservations|column_name=status|data_type=STRING\n",
      "dataset_name=flight_reservations|table_name=transactions|column_name=reservation_id|data_type=INT64\n",
      "dataset_name=flight_reservations|table_name=flights|column_name=carrier|data_type=STRING\n",
      "dataset_name=flight_reservations|table_name=reservations|column_name=reservation_id|data_type=INT64\n",
      "dataset_name=flight_reservations|table_name=transactions|column_name=amount|data_type=FLOAT64\n",
      "dataset_name=flight_reservations|table_name=transactions|column_name=transaction_id|data_type=INT64\n",
      "dataset_name=flight_reservations|table_name=flights|column_name=flight_id|data_type=INT64\n",
      "dataset_name=flight_reservations|table_name=reservations|column_name=flight_id|data_type=INT64\n",
      "dataset_name=flight_reservations|table_name=customers|column_name=customer_id|data_type=INT64\n"
     ]
    }
   ],
   "source": [
    "matched_columns_cleaned = []\n",
    "\n",
    "for doc in matched_columns_filtered:\n",
    "    dataset_name = doc['dataset_name']\n",
    "    table_name = doc['table_name']\n",
    "    column_name = doc['column_name']\n",
    "    data_type = doc['data_type']\n",
    "    matched_columns_cleaned.append(f'dataset_name={dataset_name}|table_name={table_name}|column_name={column_name}|data_type={data_type}')\n",
    "    \n",
    "matched_columns_cleaned = '\\n'.join(matched_columns_cleaned)\n",
    "logger.info(matched_columns_cleaned)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 5: Text-to-SQL generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"You are a SQL master expert capable of writing complex SQL query in BigQuery.\"\n",
    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
    "messages.append(system_message_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "human_template = \"\"\"Given the following inputs:\n",
    "USER_QUERY:\n",
    "--\n",
    "{query}\n",
    "--\n",
    "MATCHED_SCHEMA: \n",
    "--\n",
    "{matched_schema}\n",
    "--\n",
    "Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above. \n",
    "The goal is to determine the availability of hotels based on the provided info. \n",
    "\n",
    "IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this. \n",
    "IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.\n",
    "NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed. \n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "human_message = HumanMessagePromptTemplate.from_template(human_template)\n",
    "messages.append(human_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_prompt = ChatPromptTemplate.from_messages(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "request = chat_prompt.format_prompt(query=query,\n",
    "                                    matched_schema=matched_columns_cleaned).to_messages()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "SELECT \n",
      "  r.reservation_id,\n",
      "  r.flight_id,\n",
      "  r.customer_id,\n",
      "  r.status,\n",
      "  r.reservation_datetime,\n",
      "  f.departure_datetime,\n",
      "  f.origin,\n",
      "  f.destination,\n",
      "  f.arrival_datetime,\n",
      "  f.carrier\n",
      "FROM flight_reservations.reservations AS r\n",
      "JOIN flight_reservations.flights AS f\n",
      "ON r.flight_id = f.flight_id\n",
      "WHERE r.reservation_datetime BETWEEN '2023-10-10' AND '2023-10-15';\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 40.2 ms, sys: 4.46 ms, total: 44.6 ms\n",
      "Wall time: 2.76 s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "\n",
    "response = llm(request)\n",
    "sql = '\\n'.join(response.content.strip().split('\\n')[1:-1])\n",
    "logger.info(sql)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 6: Execute the generated SQL query in BigQuery"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "bq = bigquery.Client()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>reservation_id</th>\n",
       "      <th>flight_id</th>\n",
       "      <th>customer_id</th>\n",
       "      <th>status</th>\n",
       "      <th>reservation_datetime</th>\n",
       "      <th>departure_datetime</th>\n",
       "      <th>origin</th>\n",
       "      <th>destination</th>\n",
       "      <th>arrival_datetime</th>\n",
       "      <th>carrier</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6</td>\n",
       "      <td>6</td>\n",
       "      <td>6</td>\n",
       "      <td>Confirmed</td>\n",
       "      <td>2023-10-10 10:00:00</td>\n",
       "      <td>2023-11-25 06:00:00</td>\n",
       "      <td>SEA</td>\n",
       "      <td>JFK</td>\n",
       "      <td>2023-11-25 14:30:00</td>\n",
       "      <td>United</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>Confirmed</td>\n",
       "      <td>2023-10-12 11:30:00</td>\n",
       "      <td>2023-11-27 20:00:00</td>\n",
       "      <td>JFK</td>\n",
       "      <td>MIA</td>\n",
       "      <td>2023-11-27 23:30:00</td>\n",
       "      <td>American</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   reservation_id  flight_id  customer_id     status reservation_datetime  \\\n",
       "0               6          6            6  Confirmed  2023-10-10 10:00:00   \n",
       "1               7          7            6  Confirmed  2023-10-12 11:30:00   \n",
       "\n",
       "   departure_datetime origin destination    arrival_datetime   carrier  \n",
       "0 2023-11-25 06:00:00    SEA         JFK 2023-11-25 14:30:00    United  \n",
       "1 2023-11-27 20:00:00    JFK         MIA 2023-11-27 23:30:00  American  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = bq.query(sql).to_dataframe()\n",
    "df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".bq_sql_agent",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
